热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

DGLRDKit|基于AttentiveFP可视化训练模型原子权重

DGL具有许多用于化学信息学、药物与生物信息学任务的函数。DGL开发人员提供了用于可视化训练模型原子权重的代码。使用AttentiveFP构建模型后,可以可视化给定

DGL具有许多用于化学信息学、药物与生物信息学任务的函数。

DGL开发人员提供了用于可视化训练模型原子权重的代码。使用Attentive FP构建模型后,可以可视化给定分子的原子权重,意味着每个原子对目标值的贡献量。




基于Attentive FP可视化训练模型原子权重


环境准备


  • PyTorch:深度学习框架
  • DGL:基于PyTorch的库,支持深度学习以处理图形
  • RDKit:用于构建分子图并从字符串表示形式绘制结构式
  • MDTraj:用于分子动力学轨迹分析的开源库



导入库

%matplotlib inline
import matplotlib.pyplot as plt
import os
from rdkit import Chem
from rdkit import RDPathsimport dgl
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from dgl import model_zoofrom dgl.data.chem.utils import mol_to_complete_graph, mol_to_bigraphfrom dgl.data.chem.utils import atom_type_one_hot
from dgl.data.chem.utils import atom_degree_one_hot
from dgl.data.chem.utils import atom_formal_charge
from dgl.data.chem.utils import atom_num_radical_electrons
from dgl.data.chem.utils import atom_hybridization_one_hot
from dgl.data.chem.utils import atom_total_num_H_one_hot
from dgl.data.chem.utils import one_hot_encoding
from dgl.data.chem import CanonicalAtomFeaturizer
from dgl.data.chem import CanonicalBondFeaturizer
from dgl.data.chem import ConcatFeaturizer
from dgl.data.chem import BaseAtomFeaturizer
from dgl.data.chem import BaseBondFeaturizerfrom dgl.data.chem import one_hot_encoding
from dgl.data.utils import split_datasetfrom functools import partial
from sklearn.metrics import roc_auc_score

代码来源于dgl/example

DGL开发人员提供了用于可视化训练模型原子权重的代码。

使用Attentive FP构建模型后,可以可视化给定分子的原子权重,意味着每个原子对目标值的贡献量。

 

def chirality(atom):try:return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \[atom.HasProp('_ChiralityPossible')]except:return [False, False] + [atom.HasProp('_ChiralityPossible')]def collate_molgraphs(data):"""Batching a list of datapoints for dataloader.Parameters----------data : list of 3-tuples or 4-tuples.Each tuple is for a single datapoint, consisting ofa SMILES, a DGLGraph, all-task labels and optionallya binary mask indicating the existence of labels.Returns-------smiles : listList of smilesbg : BatchedDGLGraphBatched DGLGraphslabels : Tensor of dtype float32 and shape (B, T)Batched datapoint labels. B is len(data) andT is the number of total tasks.masks : Tensor of dtype float32 and shape (B, T)Batched datapoint binary mask, indicating theexistence of labels. If binary masks are notprovided, return a tensor with ones."""assert len(data[0]) in [3, 4], \'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))if len(data[0]) == 3:smiles, graphs, labels = map(list, zip(*data))masks = Noneelse:smiles, graphs, labels, masks = map(list, zip(*data))bg = dgl.batch(graphs)bg.set_n_initializer(dgl.init.zero_initializer)bg.set_e_initializer(dgl.init.zero_initializer)labels = torch.stack(labels, dim=0)if masks is None:masks = torch.ones(labels.shape)else:masks = torch.stack(masks, dim=0)return smiles, bg, labels, masksatom_featurizer = BaseAtomFeaturizer({'hv': ConcatFeaturizer([partial(atom_type_one_hot, allowable_set=['B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],encode_unknown=True),partial(atom_degree_one_hot, allowable_set=list(range(6))),atom_formal_charge, atom_num_radical_electrons,partial(atom_hybridization_one_hot, encode_unknown=True),lambda atom: [0], # A placeholder for aromatic information,atom_total_num_H_one_hot, chirality],)})
bond_featurizer = BaseBondFeaturizer({'he': lambda bond: [0 for _ in range(10)]})train_mols = Chem.SDMolSupplier('solubility.train.sdf')
train_smi =[Chem.MolToSmiles(m) for m in train_mols]
train_sol = torch.tensor([float(mol.GetProp('SOL')) for mol in train_mols]).reshape(-1,1)test_mols = Chem.SDMolSupplier('solubility.test.sdf')
test_smi = [Chem.MolToSmiles(m) for m in test_mols]
test_sol = torch.tensor([float(mol.GetProp('SOL')) for mol in test_mols]).reshape(-1,1)train_graph =[mol_to_bigraph(mol,node_featurizer=atom_featurizer, edge_featurizer=bond_featurizer) for mol in train_mols]test_graph =[mol_to_bigraph(mol,node_featurizer=atom_featurizer, edge_featurizer=bond_featurizer) for mol in test_mols]def run_a_train_epoch(n_epochs, epoch, model, data_loader,loss_criterion, optimizer):model.train()total_loss = 0losses = []for batch_id, batch_data in enumerate(data_loader):batch_datasmiles, bg, labels, masks = batch_dataif torch.cuda.is_available():bg.to(torch.device('cuda:0'))labels = labels.to('cuda:0')masks = masks.to('cuda:0')prediction = model(bg, bg.ndata['hv'], bg.edata['he'])loss = (loss_criterion(prediction, labels)*(masks != 0).float()).mean()#loss = loss_criterion(prediction, labels)#print(loss.shape)optimizer.zero_grad()loss.backward()optimizer.step()losses.append(loss.data.item())#total_score = np.mean(train_meter.compute_metric('rmse'))total_score = np.mean(losses)print('epoch {:d}/{:d}, training {:.4f}'.format( epoch + 1, n_epochs, total_score))return total_scoremodel = model_zoo.chem.AttentiveFP(node_feat_size=39,edge_feat_size=10,num_layers=2,num_timesteps=2,graph_feat_size=200,output_size=1,dropout=0.2)train_loader = DataLoader(dataset=list(zip(train_smi, train_graph, train_sol)), batch_size=128, collate_fn=collate_molgraphs)
test_loader = DataLoader(dataset=list(zip(test_smi, test_graph, test_sol)), batch_size=128, collate_fn=collate_molgraphs)loss_fn = nn.MSELoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=10 ** (-2.5), weight_decay=10 ** (-5.0),)
n_epochs = 100
epochs = []
scores = []
for e in range(n_epochs):score = run_a_train_epoch(n_epochs, e, model, train_loader, loss_fn, optimizer)epochs.append(e)scores.append(score)
model.eval()

导入用于分子可视化依赖库

import copy
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG
from IPython.display import display
import matplotlib
import matplotlib.cm as cm

定义可视化函数


  • 代码来源于DGL库。
  • DGL模型具有get_node_weight选项,该选项返回图形的node_weight。该模型具有两层GRU,因此以下代码我将0用作时间步长,因此时间步长必须为0或1。

def drawmol(idx, dataset, timestep):smiles, graph, _ = dataset[idx]print(smiles)bg = dgl.batch([graph])atom_feats, bond_feats = bg.ndata['hv'], bg.edata['he']if torch.cuda.is_available():print('use cuda')bg.to(torch.device('cuda:0'))atom_feats = atom_feats.to('cuda:0')bond_feats = bond_feats.to('cuda:0')_, atom_weights = model(bg, atom_feats, bond_feats, get_node_weight=True)assert timestep

绘制测试数据集分子

该模型预测溶解度,颜色表示红色是溶解度的积极影响,蓝色是负面影响。

target = test_loader.dataset
for i in range(len(target)):mol, aw, svg = drawmol(i, target, 0)display(SVG(svg))

。。。。。 




参考资料

1. https://github.com/dmlc/dgl/tree/master/apps/life_sci

2. https://github.com/dmlc/dgl/blob/master/python/dgl/model_zoo/chem/attentive_fp.py

3. https://pubs.acs.org/doi/full/10.1021/acs.jcim.9b00387

 


推荐阅读
  • 本文介绍如何使用阿里云的fastjson库解析包含时间戳、IP地址和参数等信息的JSON格式文本,并进行数据处理和保存。 ... [详细]
  • CentOS7源码编译安装MySQL5.6
    2019独角兽企业重金招聘Python工程师标准一、先在cmake官网下个最新的cmake源码包cmake官网:https:www.cmake.org如此时最新 ... [详细]
  • 本文详细介绍了Java编程语言中的核心概念和常见面试问题,包括集合类、数据结构、线程处理、Java虚拟机(JVM)、HTTP协议以及Git操作等方面的内容。通过深入分析每个主题,帮助读者更好地理解Java的关键特性和最佳实践。 ... [详细]
  • UNP 第9章:主机名与地址转换
    本章探讨了用于在主机名和数值地址之间进行转换的函数,如gethostbyname和gethostbyaddr。此外,还介绍了getservbyname和getservbyport函数,用于在服务器名和端口号之间进行转换。 ... [详细]
  • PHP 过滤器详解
    本文深入探讨了 PHP 中的过滤器机制,包括常见的 $_SERVER 变量、filter_has_var() 函数、filter_id() 函数、filter_input() 函数及其数组形式、filter_list() 函数以及 filter_var() 和其数组形式。同时,详细介绍了各种过滤器的用途和用法。 ... [详细]
  • 主要用了2个类来实现的,话不多说,直接看运行结果,然后在奉上源代码1.Index.javaimportjava.awt.Color;im ... [详细]
  • 深入理解 SQL 视图、存储过程与事务
    本文详细介绍了SQL中的视图、存储过程和事务的概念及应用。视图为用户提供了一种灵活的数据查询方式,存储过程则封装了复杂的SQL逻辑,而事务确保了数据库操作的完整性和一致性。 ... [详细]
  • 前言--页数多了以后需要指定到某一页(只做了功能,样式没有细调)html ... [详细]
  • 本文深入探讨了 Java 中的 Serializable 接口,解释了其实现机制、用途及注意事项,帮助开发者更好地理解和使用序列化功能。 ... [详细]
  • DNN Community 和 Professional 版本的主要差异
    本文详细解析了 DotNetNuke (DNN) 的两种主要版本:Community 和 Professional。通过对比两者的功能和附加组件,帮助用户选择最适合其需求的版本。 ... [详细]
  • ImmutableX Poised to Pioneer Web3 Gaming Revolution
    ImmutableX is set to spearhead the evolution of Web3 gaming, with its innovative technologies and strategic partnerships driving significant advancements in the industry. ... [详细]
  • 题目Link题目学习link1题目学习link2题目学习link3%%%受益匪浅!-----&# ... [详细]
  • 深入解析 Apache Shiro 安全框架架构
    本文详细介绍了 Apache Shiro,一个强大且灵活的开源安全框架。Shiro 专注于简化身份验证、授权、会话管理和加密等复杂的安全操作,使开发者能够更轻松地保护应用程序。其核心目标是提供易于使用和理解的API,同时确保高度的安全性和灵活性。 ... [详细]
  • 深入解析 Spring Security 用户认证机制
    本文将详细介绍 Spring Security 中用户登录认证的核心流程,重点分析 AbstractAuthenticationProcessingFilter 和 AuthenticationManager 的工作原理。通过理解这些组件的实现,读者可以更好地掌握 Spring Security 的认证机制。 ... [详细]
  • 本文深入探讨了HTTP请求和响应对象的使用,详细介绍了如何通过响应对象向客户端发送数据、处理中文乱码问题以及常见的HTTP状态码。此外,还涵盖了文件下载、请求重定向、请求转发等高级功能。 ... [详细]
author-avatar
手机用户2502892403
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有